package com.shujia.core

import com.shujia.core.Demo10Join.Student
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

import java.sql.{Connection, DriverManager, PreparedStatement, ResultSet}
import scala.collection.mutable.ListBuffer

object Demo18MapPartitions {
  def main(args: Array[String]): Unit = {
    println("============Spark程序开始运行============")
    // RDD是由一系列分区组成的
    // Spark会提供最佳计算位置 移动计算 不移动数据
    /**
     * 可以将Spark代码分为两个部分
     * Driver端执行的
     * Task中执行的
     */
    val conf: SparkConf = new SparkConf()
    conf.setAppName("Demo18MapPartitions")
    conf.setMaster("local")

    val sc: SparkContext = new SparkContext(conf)

    // 读取学生数据及分数数据 并将每一行数据转换成样例类对象
    //    val stuRDD: RDD[Student] = sc.parallelize(List(
    //      Student("1500100001", "zs", 20, "男", "文科一班"),
    //      Student("1500100002", "ls", 21, "女", "文科三班"),
    //      Student("1500100003", "ww", 23, "男", "文科二班"),
    //      Student("1500100004", "ll", 22, "男", "文科一班")
    //    ))

    val stuRDD: RDD[Student] = sc
      .textFile("Spark/data/students.txt")
      .map(line => {
        val splits: Array[String] = line.split(",")
        val id: String = splits(0)
        val name: String = splits(1)
        val age: Int = splits(2).toInt
        val gender: String = splits(3)
        val clazz: String = splits(4)
        Student(id, name, age, gender, clazz)
      })



    // 遍历RDD中的每个学生数据 取id 并作为PreparedStatement的参数
    stuRDD

      /**
       * flatMap中的逻辑 会被封装成Task 发送到Executor中执行
       * 在flatMap中传入的逻辑需要能够进行序列化
       * 在这里由于MySQL的连接是不能被序列化的
       */
      .flatMap(stu => {
        // 算子外部的代码会在Driver执行 算子内部的代码会被以Task形式发送到Executor中执行
        // 由于MySQL的连接不能被序列化的 所以需要将建立连接的过程放每一个Task中
        // 如果直接使用map 或者是 flatMap 或者 foreach 去遍历RDD中的每条数据
        // 那么会造成 每条数据都需要建立一次连接 建立和销毁连接的过程是非常耗时的
        // 从远程数据库MySQL 获取分数表score数据 并实现关联
        // 建立MySQL连接
        val conn: Connection = DriverManager.getConnection("jdbc:mysql://rm-bp1h7v927zia3t8iwho.mysql.rds.aliyuncs.com:3306/stu016?useSSL=false", "shujia016", "123456")
        // 创建Statement
        val pSt: PreparedStatement = conn.prepareStatement("select course_id,score from score where student_id = ?")

        val id: String = stu.id
        // 设置SQL参数
        pSt.setInt(1, id.toInt)
        // 执行查询语句
        val rs: ResultSet = pSt.executeQuery()
        val stuListBF: ListBuffer[String] = ListBuffer[String]()

        // 将每个学生的6门成绩加入ListBuffer中 并使用flatMap展开
        while (rs.next()) {
          val score_id: Int = rs.getInt("course_id")
          val score: Int = rs.getInt("score")
          stuListBF.append(s"$id,${stu.name},$score_id,$score")
        }
        // 关闭连接
        pSt.close()
        conn.close()

        stuListBF
      })
    //      .foreach(println)

    // 使用mapPartitions 或者 foreachPatitions 代替map/flatMap/foreach

    stuRDD
      .mapPartitions(stuIter => {
        // mapPartitions 会对每个分区进行遍历
        // 每个分区 即 每个Task任务 会建立一次连接 避免每条数据建立连接
        // 从远程数据库MySQL 获取分数表score数据 并实现关联
        // 建立MySQL连接
        println("创建MySQL连接")
        val conn: Connection = DriverManager.getConnection("jdbc:mysql://rm-bp1h7v927zia3t8iwho.mysql.rds.aliyuncs.com:3306/stu016?useSSL=false", "shujia016", "123456")
        // 创建Statement
        val pSt: PreparedStatement = conn.prepareStatement("select course_id,score from score where student_id = ?")

        val stuScoreIter: Iterator[String] = stuIter
          .flatMap(stu => {
            val id: String = stu.id
            // 设置SQL参数
            pSt.setInt(1, id.toInt)
            // 执行查询语句
            val rs: ResultSet = pSt.executeQuery()
            val stuListBF: ListBuffer[String] = ListBuffer[String]()

            // 将每个学生的6门成绩加入ListBuffer中 并使用flatMap展开
            while (rs.next()) {
              val score_id: Int = rs.getInt("course_id")
              val score: Int = rs.getInt("score")
              stuListBF.append(s"$id,${stu.name},$score_id,$score")
            }
            stuListBF
          })
        stuScoreIter
      })
    //      .foreach(println)

    val intRDD: RDD[Int] = sc.parallelize(List(1, 2, 3, 4, 5), 2)
    intRDD
      .mapPartitionsWithIndex((index, iter) => {
        println(s"当前遍历的分区为:$index")
        iter
          .map(i => i * i)
      }).foreach(println)


    println("============Spark程序运行结束============")


  }

}
